/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.bff.gaia.unified.sdk.io.gcp.bigquery;

import static com.bff.gaia.unified.vendor.guava.com.google.common.base.Preconditions.checkArgument;

import com.bff.gaia.unified.sdk.coders.KvCoder;
import com.bff.gaia.unified.sdk.coders.StringUtf8Coder;
import com.bff.gaia.unified.sdk.coders.VoidCoder;
import com.bff.gaia.unified.sdk.io.FileSystems;
import com.bff.gaia.unified.sdk.io.fs.ResourceId;
import com.bff.gaia.unified.sdk.options.ValueProvider;
import com.bff.gaia.unified.sdk.transforms.DoFn;
import com.bff.gaia.unified.sdk.transforms.GroupByKey;
import com.bff.gaia.unified.sdk.transforms.PTransform;
import com.bff.gaia.unified.sdk.transforms.ParDo;
import com.bff.gaia.unified.sdk.transforms.Values;
import com.bff.gaia.unified.sdk.transforms.WithKeys;
import com.bff.gaia.unified.sdk.transforms.windowing.AfterPane;
import com.bff.gaia.unified.sdk.transforms.windowing.BoundedWindow;
import com.bff.gaia.unified.sdk.transforms.windowing.GlobalWindows;
import com.bff.gaia.unified.sdk.transforms.windowing.Repeatedly;
import com.bff.gaia.unified.sdk.transforms.windowing.Window;
import com.bff.gaia.unified.sdk.values.KV;
import com.bff.gaia.unified.sdk.values.PCollection;
import com.bff.gaia.unified.sdk.values.PCollectionTuple;
import com.bff.gaia.unified.sdk.values.PCollectionView;
import com.bff.gaia.unified.sdk.values.ShardedKey;
import com.bff.gaia.unified.sdk.values.TupleTag;
import com.bff.gaia.unified.sdk.values.TupleTagList;
import com.bff.gaia.unified.vendor.guava.com.google.common.base.Strings;
import com.bff.gaia.unified.vendor.guava.com.google.common.collect.ImmutableList;
import com.bff.gaia.unified.vendor.guava.com.google.common.collect.Lists;
import com.bff.gaia.unified.vendor.guava.com.google.common.collect.Maps;
import com.google.api.services.bigquery.model.EncryptionConfiguration;
import com.google.api.services.bigquery.model.JobConfigurationLoad;
import com.google.api.services.bigquery.model.JobReference;
import com.google.api.services.bigquery.model.TableReference;
import com.google.api.services.bigquery.model.TableSchema;
import com.google.api.services.bigquery.model.TimePartitioning;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Writes partitions to BigQuery tables.
 *
 * <p>The input is a list of files corresponding to each partition of a table. These files are
 * loaded into a temporary table (or into the final table if there is only one partition). The
 * output is a {@link KV} mapping each final table to a list of the temporary tables containing its
 * data.
 *
 * <p>In the case where all the data in the files fits into a single load job, this transform loads
 * the data directly into the final table, skipping temporary tables. In this case, the output
 * {@link KV} maps the final table to itself.
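 *
 * <p>A minimal usage sketch follows; {@code files}, {@code bqServices}, {@code jobIdView},
 * {@code sideInputs}, and {@code destinations} are illustrative placeholders, not part of this
 * class:
 *
 * <pre>{@code
 * PCollection<KV<ShardedKey<String>, List<String>>> files = ...;
 * PCollection<KV<TableDestination, String>> tempTables =
 *     files.apply(
 *         "WriteTempTables",
 *         new WriteTables<>(
 *             true, // load each partition into its own temporary table
 *             bqServices,
 *             jobIdView,
 *             BigQueryIO.Write.WriteDisposition.WRITE_EMPTY,
 *             BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED,
 *             sideInputs,
 *             destinations,
 *             null, // loadJobProjectId: run load jobs in each table's project
 *             3, // maxRetryJobs
 *             false, // ignoreUnknownValues
 *             null)); // kmsKey: no customer-managed encryption
 * }</pre>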
 */
class WriteTables<DestinationT>
    extends PTransform<
        PCollection<KV<ShardedKey<DestinationT>, List<String>>>,
        PCollection<KV<TableDestination, String>>> {
  private static final Logger LOG = LoggerFactory.getLogger(WriteTables.class);

  private final boolean tempTable;
  private final BigQueryServices bqServices;
  private final PCollectionView<String> loadJobIdPrefixView;
  private final BigQueryIO.Write.WriteDisposition firstPaneWriteDisposition;
  private final BigQueryIO.Write.CreateDisposition firstPaneCreateDisposition;
  private final DynamicDestinations<?, DestinationT> dynamicDestinations;
  private final List<PCollectionView<?>> sideInputs;
  private final TupleTag<KV<TableDestination, String>> mainOutputTag;
  private final TupleTag<String> temporaryFilesTag;
  private final ValueProvider<String> loadJobProjectId;
  private final int maxRetryJobs;
  private final boolean ignoreUnknownValues;
  @Nullable private final String kmsKey;

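  // Writes each input partition's files to BigQuery with a load job. Pending jobs accumulate in
  // processElement; finishBundle then drives them all to completion, so a single bundle can have
  // many load jobs in flight at once.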
  private class WriteTablesDoFn
      extends DoFn<KV<ShardedKey<DestinationT>, List<String>>, KV<TableDestination, String>> {
    private Map<DestinationT, String> jsonSchemas = Maps.newHashMap();

    // Represents a pending BigQuery load job.
    private class PendingJobData {
      final BoundedWindow window;
      final BigQueryHelpers.PendingJob retryJob;
      final List<String> partitionFiles;
      final TableDestination tableDestination;
      final TableReference tableReference;

      public PendingJobData(
          BoundedWindow window,
          BigQueryHelpers.PendingJob retryJob,
          List<String> partitionFiles,
          TableDestination tableDestination,
          TableReference tableReference) {
        this.window = window;
        this.retryJob = retryJob;
        this.partitionFiles = partitionFiles;
        this.tableDestination = tableDestination;
        this.tableReference = tableReference;
      }
    }

    // All pending load jobs.
    private List<PendingJobData> pendingJobs = Lists.newArrayList();

    @StartBundle
    public void startBundle(StartBundleContext c) {
      // Clear the map on each bundle so we can notice side-input updates.
      // (alternative is to use a cache with a TTL).
      jsonSchemas.clear();
      pendingJobs.clear();
    }

    @ProcessElement
    public void processElement(ProcessContext c, BoundedWindow window) throws Exception {
      dynamicDestinations.setSideInputAccessorFromProcessContext(c);
      DestinationT destination = c.element().getKey().getKey();

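      // Resolve the destination's schema. CREATE_NEVER needs no schema; otherwise consult the
      // per-bundle cache before asking DynamicDestinations, and fail fast on a null schema.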
      TableSchema tableSchema;
      if (firstPaneCreateDisposition == BigQueryIO.Write.CreateDisposition.CREATE_NEVER) {
        tableSchema = null;
      } else if (jsonSchemas.containsKey(destination)) {
        tableSchema =
            BigQueryHelpers.fromJsonString(jsonSchemas.get(destination), TableSchema.class);
      } else {
        tableSchema = dynamicDestinations.getSchema(destination);
        checkArgument(
            tableSchema != null,
            "Unless create disposition is %s, a schema must be specified, i.e. "
                + "DynamicDestinations.getSchema() may not return null. "
                + "However, create disposition is %s, and %s returned null for destination %s",
            BigQueryIO.Write.CreateDisposition.CREATE_NEVER,
            firstPaneCreateDisposition,
            dynamicDestinations,
            destination);
        jsonSchemas.put(destination, BigQueryHelpers.toJsonString(tableSchema));
      }

      TableDestination tableDestination = dynamicDestinations.getTable(destination);
      checkArgument(
          tableDestination != null,
          "DynamicDestinations.getTable() may not return null, "
              + "but %s returned null for destination %s",
          dynamicDestinations,
          destination);
      TableReference tableReference = tableDestination.getTableReference();
      if (Strings.isNullOrEmpty(tableReference.getProjectId())) {
        tableReference.setProjectId(c.getPipelineOptions().as(BigQueryOptions.class).getProject());
        tableDestination = tableDestination.withTableReference(tableReference);
      }

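      // The job id is derived deterministically from the prefix, destination, partition, and
      // pane index, so a retried element can be matched up with a previously issued job.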
      Integer partition = c.element().getKey().getShardNumber();
      List<String> partitionFiles = Lists.newArrayList(c.element().getValue());
      String jobIdPrefix =
          BigQueryHelpers.createJobId(
              c.sideInput(loadJobIdPrefixView), tableDestination, partition, c.pane().getIndex());

      if (tempTable) {
        // This is a temp table. Create a new one for each partition and each pane.
        tableReference.setTableId(jobIdPrefix);
      }

      BigQueryIO.Write.WriteDisposition writeDisposition = firstPaneWriteDisposition;
      BigQueryIO.Write.CreateDisposition createDisposition = firstPaneCreateDisposition;
      if (c.pane().getIndex() > 0 && !tempTable) {
        // If writing directly to the destination, then the table is created on the first write
        // and we should change the disposition for subsequent writes.
        writeDisposition = BigQueryIO.Write.WriteDisposition.WRITE_APPEND;
        createDisposition = BigQueryIO.Write.CreateDisposition.CREATE_NEVER;
      } else if (tempTable) {
        // In this case, we are writing to a temp table and always need to create it.
        // WRITE_TRUNCATE is set so that we properly handle retries of this pane.
        writeDisposition = BigQueryIO.Write.WriteDisposition.WRITE_TRUNCATE;
        createDisposition = BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED;
      }

      BigQueryHelpers.PendingJob retryJob =
          startLoad(
              bqServices.getJobService(c.getPipelineOptions().as(BigQueryOptions.class)),
              bqServices.getDatasetService(c.getPipelineOptions().as(BigQueryOptions.class)),
              jobIdPrefix,
              tableReference,
              tableDestination.getTimePartitioning(),
              tableSchema,
              partitionFiles,
              writeDisposition,
              createDisposition);
      pendingJobs.add(
          new PendingJobData(window, retryJob, partitionFiles, tableDestination, tableReference));
    }

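    // Drives every load job started in this bundle to completion. On success, optionally patches
    // the table description, then emits the (table, temp-table) pair and the partition's files
    // (to the garbage-collection output) in the element's original window.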
    @FinishBundle
    public void finishBundle(FinishBundleContext c) throws Exception {
      BigQueryServices.DatasetService datasetService =
          bqServices.getDatasetService(c.getPipelineOptions().as(BigQueryOptions.class));

      BigQueryHelpers.PendingJobManager jobManager = new BigQueryHelpers.PendingJobManager();
      for (PendingJobData pendingJob : pendingJobs) {
        jobManager =
            jobManager.addPendingJob(
                pendingJob.retryJob,
                // Lambda called when the job is done.
                j -> {
                  try {
                    if (pendingJob.tableDestination.getTableDescription() != null) {
                      TableReference ref = pendingJob.tableReference;
                      datasetService.patchTableDescription(
                          ref.clone()
                              .setTableId(
                                  BigQueryHelpers.stripPartitionDecorator(ref.getTableId())),
                          pendingJob.tableDestination.getTableDescription());
                    }
                    c.output(
                        mainOutputTag,
                        KV.of(
                            pendingJob.tableDestination,
                            BigQueryHelpers.toJsonString(pendingJob.tableReference)),
                        pendingJob.window.maxTimestamp(),
                        pendingJob.window);
                    for (String file : pendingJob.partitionFiles) {
                      c.output(
                          temporaryFilesTag,
                          file,
                          pendingJob.window.maxTimestamp(),
                          pendingJob.window);
                    }
                    return null;
                  } catch (IOException | InterruptedException e) {
                    return e;
                  }
                });
      }
      jobManager.waitForDone();
    }
  }

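  /** Deletes batches of temporary files after the corresponding load jobs have succeeded. */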
  private static class GarbageCollectTemporaryFiles extends DoFn<Iterable<String>, Void> {
    @ProcessElement
    public void processElement(ProcessContext c) throws Exception {
      removeTemporaryFiles(c.element());
    }
  }

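  /**
   * Creates a {@link WriteTables} transform. When {@code tempTable} is true, each partition is
   * loaded into its own temporary table (always created, and truncated on retry); otherwise files
   * are loaded directly into the final tables, with the dispositions relaxed to append after the
   * first pane.
   */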
  public WriteTables(
      boolean tempTable,
      BigQueryServices bqServices,
      PCollectionView<String> loadJobIdPrefixView,
      BigQueryIO.Write.WriteDisposition writeDisposition,
      BigQueryIO.Write.CreateDisposition createDisposition,
      List<PCollectionView<?>> sideInputs,
      DynamicDestinations<?, DestinationT> dynamicDestinations,
      @Nullable ValueProvider<String> loadJobProjectId,
      int maxRetryJobs,
      boolean ignoreUnknownValues,
      @Nullable String kmsKey) {
    this.tempTable = tempTable;
    this.bqServices = bqServices;
    this.loadJobIdPrefixView = loadJobIdPrefixView;
    this.firstPaneWriteDisposition = writeDisposition;
    this.firstPaneCreateDisposition = createDisposition;
    this.sideInputs = sideInputs;
    this.dynamicDestinations = dynamicDestinations;
    this.mainOutputTag = new TupleTag<>("WriteTablesMainOutput");
    this.temporaryFilesTag = new TupleTag<>("TemporaryFiles");
    this.loadJobProjectId = loadJobProjectId;
    this.maxRetryJobs = maxRetryJobs;
    this.ignoreUnknownValues = ignoreUnknownValues;
    this.kmsKey = kmsKey;
  }

  @Override
  public PCollection<KV<TableDestination, String>> expand(
      PCollection<KV<ShardedKey<DestinationT>, List<String>>> input) {
    PCollectionTuple writeTablesOutputs =
        input.apply(
            ParDo.of(new WriteTablesDoFn())
                .withSideInputs(sideInputs)
                .withOutputTags(mainOutputTag, TupleTagList.of(temporaryFilesTag)));

    // Garbage collect temporary files.
    // We mustn't start garbage collecting files until we are assured that the WriteTablesDoFn has
    // succeeded in loading those files and won't be retried. Otherwise, we might fail part of the
    // way through deleting temporary files, and retry WriteTablesDoFn. This will then fail due
    // to missing files, causing either the entire workflow to fail or get stuck (depending on how
    // the runner handles persistent failures).

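    // Re-windowing into the global window with an eager, repeated trigger lets files be deleted
    // as soon as they are emitted; the single null key funnels them through one GroupByKey so
    // deletions happen in batches.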
    writeTablesOutputs
        .get(temporaryFilesTag)
        .setCoder(StringUtf8Coder.of())
        .apply(WithKeys.of((Void) null))
        .setCoder(KvCoder.of(VoidCoder.of(), StringUtf8Coder.of()))
        .apply(
            Window.<KV<Void, String>>into(new GlobalWindows())
                .triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(1)))
                .discardingFiredPanes())
        .apply(GroupByKey.create())
        .apply(Values.create())
        .apply(ParDo.of(new GarbageCollectTemporaryFiles()));

    return writeTablesOutputs.get(mainOutputTag);
  }

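  /**
   * Builds a {@link BigQueryHelpers.PendingJob} whose callbacks start, poll, and look up a
   * BigQuery load job for the given files; the job is not run here, but later by the {@code
   * BigQueryHelpers.PendingJobManager} in {@code finishBundle}.
   */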
  private BigQueryHelpers.PendingJob startLoad(
      BigQueryServices.JobService jobService,
      BigQueryServices.DatasetService datasetService,
      String jobIdPrefix,
      TableReference ref,
      @Nullable TimePartitioning timePartitioning,
      @Nullable TableSchema schema,
      List<String> gcsUris,
      BigQueryIO.Write.WriteDisposition writeDisposition,
      BigQueryIO.Write.CreateDisposition createDisposition) {
    JobConfigurationLoad loadConfig =
        new JobConfigurationLoad()
            .setDestinationTable(ref)
            .setSchema(schema)
            .setSourceUris(gcsUris)
            .setWriteDisposition(writeDisposition.name())
            .setCreateDisposition(createDisposition.name())
            .setSourceFormat("NEWLINE_DELIMITED_JSON")
            .setIgnoreUnknownValues(ignoreUnknownValues);
    if (timePartitioning != null) {
      loadConfig.setTimePartitioning(timePartitioning);
    }
    if (kmsKey != null) {
      loadConfig.setDestinationEncryptionConfiguration(
          new EncryptionConfiguration().setKmsKeyName(kmsKey));
    }
    String projectId = loadJobProjectId == null ? ref.getProjectId() : loadJobProjectId.get();
    String bqLocation =
        BigQueryHelpers.getDatasetLocation(datasetService, ref.getProjectId(), ref.getDatasetId());

    BigQueryHelpers.PendingJob retryJob =
        new BigQueryHelpers.PendingJob(
            // Function to load the data.
            jobId -> {
              JobReference jobRef =
                  new JobReference()
                      .setProjectId(projectId)
                      .setJobId(jobId.getJobId())
                      .setLocation(bqLocation);
              LOG.info(
                  "Loading {} files into {} using job {}, job id iteration {}",
                  gcsUris.size(),
                  ref,
                  jobRef,
                  jobId.getRetryIndex());
              try {
                jobService.startLoadJob(jobRef, loadConfig);
              } catch (IOException | InterruptedException e) {
                LOG.warn("Load job {} failed with {}", jobRef, e.toString());
                throw new RuntimeException(e);
              }
              return null;
            },
            // Function to poll the result of a load job.
            jobId -> {
              JobReference jobRef =
                  new JobReference()
                      .setProjectId(projectId)
                      .setJobId(jobId.getJobId())
                      .setLocation(bqLocation);
              try {
                return jobService.pollJob(jobRef, BatchLoads.LOAD_JOB_POLL_MAX_RETRIES);
              } catch (InterruptedException e) {
                throw new RuntimeException(e);
              }
            },
            // Function to look up a job.
            jobId -> {
              JobReference jobRef =
                  new JobReference()
                      .setProjectId(projectId)
                      .setJobId(jobId.getJobId())
                      .setLocation(bqLocation);
              try {
                return jobService.getJob(jobRef);
              } catch (InterruptedException | IOException e) {
                throw new RuntimeException(e);
              }
            },
            maxRetryJobs,
            jobIdPrefix);
    return retryJob;
  }

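  /** Deletes the given temporary files through {@link FileSystems#delete}. */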
  static void removeTemporaryFiles(Iterable<String> files) throws IOException {
    ImmutableList.Builder<ResourceId> fileResources = ImmutableList.builder();
    for (String file : files) {
      fileResources.add(FileSystems.matchNewResource(file, false /* isDirectory */));
    }
    FileSystems.delete(fileResources.build());
  }
}